// Coded using Processing www.processing.org
// Trained using sparse Gaussian mutations
// The cost landscape is so smooth that
// back-propagation would likely work
// well and be faster.
import java.util.concurrent.*;
// Number of training examples (rows of exampleIn / exampleTarget).
final int nExamples=100;
// Number of worker threads; also the number of parties of the barrier below.
final int nThreads=2;
// Length of every input, target and work vector.
final int vecSize=256;
// The network being trained: vecSize wide, depth 8.
SWNet8 net=new SWNet8(vecSize, 8);
// Mutates 1000 randomly picked parameter slots per step; net.scale is passed
// as the clamp limit and 0.001 as the noise factor.
Mutator mut=new Mutator(1000, net.scale,0.001);
// Workers meet here after each evaluation; CBRun is the barrier action.
CyclicBarrier cb=new CyclicBarrier(nThreads, new CBRun());
// Training inputs and their targets, filled in by setup().
float[][] exampleIn=new float[nExamples][vecSize];
float[][] exampleTarget=new float[nExamples][vecSize];
// Cleared by stopTraining() to request a stop; read by CBRun.
volatile boolean shouldRun;
// Loop condition of the workers; cleared by CBRun once shouldRun is false.
volatile boolean shouldRunWorker;
// Sum of the workers' costs for the currently mutated ("child") parameters.
volatile float childCost;
// Cost of the accepted ("parent") parameters; starts at +infinity so any
// finite first cost compares as lower.
volatile float parentCost=Float.POSITIVE_INFINITY;
// Incremented once per run of the barrier action; displayed by draw().
volatile int epoch;
// Processing entry point: opens a 200x200 window, builds the training set
// and starts the worker threads.
void setup() {
  size(200, 200);
  // Example k is the k-th unit vector, and its target is the same vector.
  int k=0;
  while (k<nExamples) {
    exampleIn[k][k]=1f;
    exampleTarget[k][k]=1f;
    k++;
  }
  startTraining();
}
// Processing frame callback: clears the window and prints the current epoch
// count and the cost of the accepted parameters.
void draw() {
  background(0);
  final String epochLabel="Epoch: "+epoch;
  text(epochLabel, 5, 20);
  final String costLabel="Cost: "+parentCost;
  text(costLabel, 5, 50);
}

// Requests that training stop. Only the flag is cleared here; the barrier
// action (CBRun) sees it on its next run and then clears shouldRunWorker,
// which ends the worker loops.
void stopTraining() {
  shouldRun=false;
}

// Applies the first mutation and launches nThreads workers, each evaluating
// a contiguous slice [low, high) of the examples.
//
// Slice bounds are computed proportionally so the slices always cover all
// nExamples, even when nExamples is not a multiple of nThreads (a fixed
// step of nExamples/nThreads would silently drop the trailing examples).
// NOTE(review): nothing here guards against a second call while workers
// from an earlier call are still running - confirm callers never do that.
void startTraining() {
  mut.mutate(net.params); // Have the mutation ready for the threads
  shouldRun=true;
  shouldRunWorker=true;
  for (int i=0; i<nThreads; i++) {
    // long arithmetic keeps i*nExamples from overflowing for large sets
    final int low=(int)((long)i*nExamples/nThreads);
    final int high=(int)((long)(i+1)*nExamples/nThreads);
    new Worker(low, high).start();
  }
}

// Adds one worker's partial cost to childCost and then waits at the barrier
// until every worker has reported.
//
// cost: the summed cost of the calling worker's slice of examples.
//
// If the wait fails, the failure is printed and shouldRunWorker is cleared
// so the worker loops end. Without that, a broken barrier makes every later
// await() fail immediately and the workers would spin forever, printing
// and adding to childCost.
void setCost(float cost) {
  synchronized(this) {
    // childCost+=cost is a read-modify-write, so it needs the lock
    childCost+=cost;
  }
  try {
    cb.await();
  }
  catch(InterruptedException e) {
    // Restore the interrupt status for whoever inspects this thread next
    Thread.currentThread().interrupt();
    shouldRunWorker=false;
    println(e);
  }
  catch(Exception e) {
    // BrokenBarrierException, or an exception thrown by the barrier action
    shouldRunWorker=false;
    println(e);
  }
}

// Barrier action: decides whether the mutation just evaluated is kept, then
// either prepares the next mutation or tells the workers to stop.
final class CBRun implements Runnable {
  public void run() {
    epoch++;
    // Keep the mutation only when it lowered the total cost; otherwise
    // restore the previous parameter values.
    final boolean improved=childCost<parentCost;
    if (improved) {
      parentCost=childCost;
    } else {
      mut.undo(net.params);
    }
    childCost=0f;
    if (!shouldRun) {
      // A stop was requested: no new mutation, let the worker loops end.
      shouldRunWorker=false;
      return;
    }
    mut.mutate(net.params);
  }
}
// Thread that repeatedly evaluates the network on examples [low, high) and
// reports the summed cost through setCost().
final class Worker extends Thread {
  final int low;
  final int high;
  // Scratch buffer that receives the network output for one example.
  final float[] work=new float[vecSize];

  Worker(int low, int high) {
    this.low=low;
    this.high=high;
  }

  public void run() {
    while (shouldRunWorker) {
      setCost(sliceCost());
    }
  }

  // Runs the network on every example of this worker's slice and returns
  // the sum of their costL2 values.
  private float sliceCost() {
    float total=0f;
    for (int idx=low; idx<high; idx++) {
      net.recall(work, exampleIn[idx]);
      total+=costL2(exampleTarget[idx], work);
    }
    return total;
  }
}

// Network made of `depth` layers. Each layer treats the vector as groups of
// 8 elements, mixes every group with sign-switched weights, and then applies
// wht8() to the whole vector.
//
// Parameter layout: every input element of every layer owns 16 consecutive
// floats in `params`. The first 8 are its weights towards the 8 outputs of
// its group when the element is negative, the last 8 when it is not.
// NOTE(review): the code presumably expects vecLen to be a power of two that
// is at least 8; nothing here checks it.
final class SWNet8 {
  final int vecLen;
  final int depth;
  final float scale;
  final float[] params;
  final float[] flips;

  // vecLen: length of the vectors passed to recall().
  // depth:  number of layers.
  SWNet8(int vecLen, int depth) {
    this.vecLen = vecLen;
    this.depth = depth;
    scale = 1f / sqrt(vecLen>>>3);
    params = new float[16 * vecLen * depth];
    flips = new float[vecLen];

    // Fixed sign pattern applied to the input, pre-scaled by 1/sqrt(vecLen).
    final float sc = 1f / sqrt(vecLen);
    for (int k = 0; k < vecLen; k++) {
      if (sin(k * 1.895) < 0) {
        flips[k] = -sc;
      } else {
        flips[k] = sc;
      }
    }

    // Initial weights: element number (block & 7) of a group feeds only the
    // output with the same number, with weight `scale` in both sign branches.
    for (int off = 0; off < params.length; off += 16) {
      final int at = off + ((off >>> 4) & 7);
      params[at] = scale;
      params[at + 8] = scale;
    }
  }

  // Computes the network output for `input` and leaves it in `result`.
  // Both arrays are indexed from 0 to vecLen-1; `input` is only read.
  void recall(float[] result, float[] input) {
    for (int k = 0; k < vecLen; k++) {
      result[k] = input[k] * flips[k];
    }
    wht(result);

    final float[] acc = new float[8]; // the 8 outputs of the current group
    int paIdx = 0;
    for (int layer = 0; layer < depth; layer++) {
      for (int base = 0; base < vecLen; base += 8) {
        // The first element of the group initialises the accumulators ...
        float x = result[base];
        int row = (x < 0f) ? paIdx : paIdx + 8;
        for (int o = 0; o < 8; o++) {
          acc[o] = x * params[row + o];
        }
        paIdx += 16;
        // ... and the remaining seven are added on in index order.
        for (int k = 1; k < 8; k++) {
          x = result[base + k];
          row = (x < 0f) ? paIdx : paIdx + 8;
          for (int o = 0; o < 8; o++) {
            acc[o] += x * params[row + o];
          }
          paIdx += 16;
        }
        // Written back only after all 8 inputs of the group were read.
        for (int o = 0; o < 8; o++) {
          result[base + o] = acc[o];
        }
      }
      wht8(result);
    }
  }
}
// The width-8 neural layers provide 8-way connectivity, which means
// you only need to do a partial Walsh-Hadamard transform to get full
// connectivity, as provided by wht8().
void wht8(float[] vec) {
  // In-place sum/difference butterflies for half-sizes 8, 16, 32, ... while
  // the half-size is below vec.length. The half-sizes 1, 2 and 4 are not
  // performed here.
  final int len = vec.length;
  for (int half = 8; half < len; half <<= 1) {
    final int stride = half << 1;
    for (int start = 0; start < len; start += stride) {
      final int end = start + half;
      for (int k = start; k < end; k++) {
        final float lo = vec[k];
        final float hi = vec[k + half];
        vec[k] = lo + hi;
        vec[k + half] = lo - hi;
      }
    }
  }
}
// Full Fast Walsh Hadamard Transform
void wht(float[] vec) {
  // In-place, unnormalised transform. Each run of 8 elements first gets a
  // complete 8-point transform (three butterfly stages written out below),
  // then the butterflies continue with half-sizes 8, 16, ... across the
  // whole array.
  final int len = vec.length;
  for (int base = 0; base < len; base += 8) {
    final float x0 = vec[base];
    final float x1 = vec[base + 1];
    final float x2 = vec[base + 2];
    final float x3 = vec[base + 3];
    final float x4 = vec[base + 4];
    final float x5 = vec[base + 5];
    final float x6 = vec[base + 6];
    final float x7 = vec[base + 7];

    // Stage 1: neighbouring pairs.
    final float p0 = x0 + x1;
    final float p1 = x0 - x1;
    final float p2 = x2 + x3;
    final float p3 = x2 - x3;
    final float p4 = x4 + x5;
    final float p5 = x4 - x5;
    final float p6 = x6 + x7;
    final float p7 = x6 - x7;

    // Stage 2: partners two apart.
    final float q0 = p0 + p2;
    final float q1 = p1 + p3;
    final float q2 = p0 - p2;
    final float q3 = p1 - p3;
    final float q4 = p4 + p6;
    final float q5 = p5 + p7;
    final float q6 = p4 - p6;
    final float q7 = p5 - p7;

    // Stage 3: partners four apart, stored straight back.
    vec[base] = q0 + q4;
    vec[base + 1] = q1 + q5;
    vec[base + 2] = q2 + q6;
    vec[base + 3] = q3 + q7;
    vec[base + 4] = q0 - q4;
    vec[base + 5] = q1 - q5;
    vec[base + 6] = q2 - q6;
    vec[base + 7] = q3 - q7;
  }

  // Remaining stages, half-size 8 and up.
  for (int half = 8; half < len; half <<= 1) {
    final int stride = half << 1;
    for (int start = 0; start < len; start += stride) {
      final int end = start + half;
      for (int k = start; k < end; k++) {
        final float lo = vec[k];
        final float hi = vec[k + half];
        vec[k] = lo + hi;
        vec[k + half] = lo - hi;
      }
    }
  }
}

// Sum of squared difference cost
float costL2(float[] tar, float[] vec) {
  // Accumulates (vec[k] - tar[k])^2 over every index of vec, in index order.
  float sum = 0;
  int k = 0;
  while (k < vec.length) {
    final float diff = vec[k] - tar[k];
    sum += diff * diff;
    k++;
  }
  return sum;
}

// Applies a batch of random perturbations to a parameter array and remembers
// enough to take the whole batch back again.
class Mutator {
  float[] previous; // value each touched slot held before it was perturbed
  int[] pIdx;       // index of each touched slot, in the order touched
  float limit;      // perturbed values must stay strictly inside (-limit, limit)
  float noise;      // factor applied to the Gaussian step
  Rnd256 rng;

  // size:  number of perturbations per mutate() call.
  // limit: magnitude bound, also used to scale the step.
  // noise: extra factor on the step size.
  Mutator(int size, float limit, float noise) {
    this.previous = new float[size];
    this.pIdx = new int[size];
    this.limit = limit;
    this.noise = noise;
    this.rng = new Rnd256();
  }

  // Perturbs previous.length randomly chosen slots of vec (a slot can be
  // chosen more than once). A step that would reach or cross +/-limit is
  // dropped and the slot keeps its value.
  void mutate(float[] vec) {
    final int count = previous.length;
    for (int slot = 0; slot < count; slot++) {
      final int target = rng.nextIntEx(vec.length);
      final float old = vec[target];
      pIdx[slot] = target;
      previous[slot] = old;
      final float step = limit * randomGaussian()*noise;
      final float candidate = old + step;
      final boolean outside = candidate >= this.limit || candidate <= -this.limit;
      vec[target] = outside ? old : candidate;
    }
  }

  // Restores the values saved by the latest mutate(). Walking backwards
  // makes a slot that was hit several times end on its oldest saved value.
  void undo(float[] vec) {
    int slot = previous.length;
    while (--slot >= 0) {
      vec[pIdx[slot]] = previous[slot];
    }
  }
}

// Pseudo-random generator with four 64-bit state words and helpers for the
// value kinds the sketch needs.
final class Rnd256 {
  final long PHI = 0x9E3779B97F4A7C15L;
  private long s0, s1, s2, s3;

  // Seeds from System.nanoTime().
  Rnd256() {
    this(System.nanoTime());
  }

  // Each state word is the mixed value of seed plus 1..4 times PHI.
  Rnd256(long seed) {
    long z = seed + PHI;
    s0 = staffordMix13(z);
    z += PHI;
    s1 = staffordMix13(z);
    z += PHI;
    s2 = staffordMix13(z);
    z += PHI;
    s3 = staffordMix13(z);
  }

  // Returns the next 64 bits and advances the state.
  long nextLong() {
    // Output is derived from s1 before the state is stepped.
    final long out = Long.rotateLeft(s1 * 5L, 7) * 9L;
    final long shifted = s1 << 17;
    s2 ^= s0;
    s3 ^= s1;
    s1 ^= s2;
    s0 ^= s3;
    s2 ^= shifted;
    s3 = Long.rotateLeft(s3, 45);
    return out;
  }

  // Top 24 bits of the next long, multiplied by 5.9604645E-8f.
  float nextFloat() {
    final long top24 = nextLong() >>> 40;
    return top24 * 5.9604645E-8f;
  }

  // True when the sign bit of the next long is set.
  boolean nextBoolean() {
    return (nextLong() >>> 63) != 0L;
  }

  // Integer from 0 up to but excluding ex: the high 32 bits of the product
  // of ex and a 32-bit random value.
  int nextIntEx(int ex) {
    final long hi32 = nextLong() >>> 32;
    return (int) ((hi32 * ex) >>> 32);
  }

  // Like nextIntEx, but inc itself can also be returned.
  int nextIntInc(int inc) {
    final long hi32 = nextLong() >>> 32;
    return (int) ((hi32 * (inc + 1L)) >>> 32);
  }

  // Builds a float straight from random bits, with bit 30 cleared and
  // bits 29 and 28 set.
  float mutate() {
    final int bits = (((int) nextLong()) & 0xbfff_ffff) | 0x3000_0000;
    return Float.intBitsToFloat(bits);
  }

  // 64-bit mixing function: two xor-shift-multiply rounds and a final
  // xor-shift.
  long staffordMix13(long z) {
    final long r1 = (z ^ (z >>> 30)) * 0xBF58476D1CE4E5B9L;
    final long r2 = (r1 ^ (r1 >>> 27)) * 0x94D049BB133111EBL;
    return r2 ^ (r2 >>> 31);
  }
}
